import time
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from transformers import pipeline, logging
# Keep the benchmark output readable: silence noisy library warnings and logs.
for _warning_type in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_warning_type)
logging.set_verbosity_error()  # `logging` here is transformers' logging module
pd.options.mode.chained_assignment = None  # turn off pandas chained-assignment warnings

# Load the PolNLI dataset and draw a fixed, reproducible sample of premises
# from its test split; these are the documents every model is timed on.
ndocs = 5000
ds = load_dataset('mlburnham/Pol_NLI')
test = ds['test'].to_pandas().sample(ndocs, random_state=1)
docs = test['premise'].tolist()

# One result dict per benchmarked model is collected here.
timings = []

def time_it(pipe, docs, n=1):
    """
    Benchmark the time taken by a model pipeline.

    Each run classifies every document in ``docs`` against the fixed label
    'This text is about politics.' and is timed with a monotonic,
    high-resolution clock.

    Args:
        pipe (callable): The model pipeline to benchmark. Its ``device.type`` is
            reported as the hardware and, when available, ``model.name_or_path``
            is used to derive the reported model name.
        docs (list): A list of documents that will be passed to the pipe.
        n (int): Number of times to run the benchmark. Must be at least 1.
            Defaults to 1.

    Returns:
        dict: A dictionary with the keys 'Model', 'Hardware', 'Time' (average
              seconds per run), 'Time_SE', 'DPS' (documents per second, computed
              from the average time) and 'DPS_SE'. Both standard errors are
              None when n == 1.

    Raises:
        ValueError: If n is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1")

    # Throughput must be based on the documents actually processed, not on a
    # module-level constant that may not match `docs`.
    n_docs = len(docs)
    times = []

    for i in range(n):
        # perf_counter is monotonic and high-resolution, so the measurement is
        # not affected by system clock adjustments.
        start_time = time.perf_counter()
        pipe(docs, 'This text is about politics.', hypothesis_template='{}')
        elapsed_time = time.perf_counter() - start_time
        times.append(elapsed_time)

        print(f"Run {i + 1}/{n} - Elapsed time: {elapsed_time:.2f} seconds")

    # Calculate the average time and DPS
    avg_time = np.mean(times)
    dps = n_docs / avg_time

    # Calculate standard errors if n > 1 (sample std needs at least two runs)
    if n > 1:
        time_se = np.std(times, ddof=1) / np.sqrt(n)
        per_run_dps = n_docs / np.asarray(times)
        dps_se = np.std(per_run_dps, ddof=1) / np.sqrt(n)
    else:
        time_se = None
        dps_se = None

    print(f"Average elapsed time: {avg_time:.2f} seconds")
    print(f"Average DPS: {dps}")

    if n > 1:
        print(f"Standard error (Time): {time_se:.4f} seconds")
        print(f"Standard error (DPS): {dps_se:.4f}")

    # Only touch the CUDA allocator when CUDA is actually present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Label the result with the model the pipeline really wraps, so the name
    # cannot drift from whatever a global variable currently holds.
    model_id = getattr(getattr(pipe, 'model', None), 'name_or_path', None)
    if not model_id:
        # Backward-compatible fallback for pipe-like callables without model
        # metadata: use the module-level `model` identifier if one is defined.
        model_id = globals().get('model', 'unknown')

    res = {
        'Model': str(model_id).split('/')[-1],
        'Hardware': pipe.device.type,
        'Time': avg_time,
        'Time_SE': time_se,
        'DPS': dps,
        'DPS_SE': dps_se
    }
    return res

###############################################
## DEBATE DeBERTa benchmarks: base, then large
###############################################
# `model` is deliberately a module-level name: time_it can read it when
# labelling the results.
for model in (
    "mlburnham/Political_DEBATE_DeBERTa_base_v1.1",
    "mlburnham/Political_DEBATE_DeBERTa_large_v1.1",
):
    pipe = pipeline(
        "zero-shot-classification",
        model=model,
        device='cpu',
        batch_size=8,
        torch_dtype=torch.bfloat16,
    )

    # Warm-up pass; its result is discarded.
    time_it(pipe, docs, n=1)

    # Measured benchmark: 10 runs, recorded for export.
    res = time_it(pipe, docs, n=10)
    timings.append(res)

# export
out_path = Path('../data/cpu_timings.csv')
# Create the output directory first, so a missing folder cannot make the
# write fail and throw away the results of a long benchmark run.
out_path.parent.mkdir(parents=True, exist_ok=True)
cpu = pd.DataFrame(timings)
cpu.to_csv(out_path, index = False)